#include "StdAfx.h"
#include "PolicyFileWriter.h"
#include "HuginApiAddons.h"

//////////////////////////////////////////////////////////////////////////
//Author	:	Ross Conroy ross.conroy@tees.ac.uk
//Date		:	13/05/2014
//
//This class writes the policy of a domain to a file or files per action.
//The format of the files is as follows
//Node Name,Action1,...,ActionN
//Parent Name 1,Parent State 1,...,Parent StateN
//...
//Parent Name n,Parent State 1,...,Parent StateN
//Parent 1 Action ID,...,ParentN Action ID,Outcome 1,...,OutcomeN
//
//Example Output Single File
//NODE
//Action0,Attack,Escape
//PARENTS
//Observation0,See,DontSee
//POLICY
//See,Attack
//DontSee,Escape
//
//NODE
//Action1,Attack,Escape
//PARENTS
//Observation0,See,DontSee
//Observation1,See,DontSee
//Action0,Attack,Escape
//POLICY
//See,See,Attack,
//See,See,Escape,
//See,DontSee,Attack,
//See,DontSee,Escape,
//etc.......
//
//The output can be one file for all nodes, e.g. policy.txt, or an
//individual file per action node, named after that node,
//e.g. Action0.txt
//
//Update	:	15/09/2014
//Author	:	Ross Conroy ross.conroy@tees.ac.uk
//Changed to mimic the policy tree builder: policy files are now written
//in CSV format (header "SearchString,Action,"), one file per decision node
//////////////////////////////////////////////////////////////////////////

//Default constructor. No members are initialised here; all work happens
//in WritePolicyToFile.
PolicyFileWriter::PolicyFileWriter(void) {}

//////////////////////////////////////////////////////////////////////////
//Writes the policy of every decision node of a time-expanded DID, one
//CSV file per decision node.
//
//For each time step t from FindFirstTimeStep to FindLastTimeStep:
//  - the decision node "<actionPrefix><t>" is looked up;
//  - for t after the first step, the observation node
//    "<observationPrefix><t>" is added to the history;
//  - every combination of states of the history nodes (previous decisions
//    and observations, interleaved D,O,D,O,...) is entered as evidence,
//    the domain is propagated, and the best decision is stored under a key
//    made of the selected state labels joined by "-";
//  - the table is written via WriteToFile to
//    <outputFolder><decision node name><fileExt>.
//
//inputDID          : decision network; compiled here. It is left holding the
//                    findings of the last combination entered.
//actionPrefix      : name prefix of decision nodes (the suffix is the step)
//observationPrefix : name prefix of observation nodes (suffix is the step)
//outputFolder      : prepended verbatim to the file name, so it must end
//                    with a path separator
//fileExt           : appended verbatim to the file name (e.g. ".csv")
//
//Returns false if inputDID is null or an expected decision/observation
//node is missing (or has the wrong type). Files for earlier steps may
//already have been written by then. Returns true otherwise.
//////////////////////////////////////////////////////////////////////////
bool PolicyFileWriter::WritePolicyToFile(Domain * inputDID, string actionPrefix, string observationPrefix, string outputFolder, string fileExt)
{
	if(inputDID == NULL)
	{
		return false;
	}

	//Compile so that propagation and expected utilities are available
	inputDID->compile();

	ExpandTimeSteps expander;
	int firstStep = expander.FindFirstTimeStep(inputDID);
	int lastStep = expander.FindLastTimeStep(inputDID);

	//Past decisions and observations in the order
	//D(first), O(first+1), D(first+1), O(first+2), ...
	//This order defines both the combination layout and the key layout.
	vector<DiscreteNode*> history;

	for(int t = firstStep; t <= lastStep; t++)
	{
		hash_map<string, list<string>> policy;

		stringstream decName;
		decName << actionPrefix << t;
		DiscreteDecisionNode * decNode = dynamic_cast<DiscreteDecisionNode*>(inputDID->getNodeByName(decName.str()));
		if(decNode == NULL)
		{
			return false;
		}

		if(t > firstStep)
		{
			stringstream obsName;
			obsName << observationPrefix << t;
			DiscreteChanceNode * obsNode = dynamic_cast<DiscreteChanceNode*>(inputDID->getNodeByName(obsName.str()));
			if(obsNode == NULL)
			{
				return false;
			}
			history.push_back(obsNode);
		}

		if(history.empty())
		{
			//First decision: no history, so a single unconditional entry
			policy.insert(make_pair("", GetBestDecision(decNode)));
		}
		else
		{
			//Number of states of each history node, for the combination generator
			list<int> stateSizes;
			for(size_t k = 0; k < history.size(); k++)
			{
				stateSizes.push_back(static_cast<int>(history[k]->getNumberOfStates()));
			}

			CombinationGenerator combi;
			list<list<int>> combinations = combi.GenerateCombinations(stateSizes);

			//Enter each combination as evidence and record the best action
			for(list<list<int>>::const_iterator it = combinations.begin(); it != combinations.end(); ++it)
			{
				inputDID->retractFindings();

				const list<int>& combination = *it;
				stringstream key;
				list<int>::const_iterator itc = combination.begin();

				for(size_t c = 0; c < history.size() && itc != combination.end(); c++, ++itc)
				{
					DiscreteNode * histNode = history[c];
					histNode->selectState(*itc);

					if(c > 0)
					{
						key << "-";
					}
					key << histNode->getStateLabel(*itc);
				}

				inputDID->propagate();
				policy.insert(make_pair(key.str(), GetBestDecision(decNode)));
			}
		}

		//This decision becomes part of the history for later steps
		history.push_back(decNode);

		WriteToFile(outputFolder, decNode->getName(), fileExt, policy);
	}

	return true;
}

//Writes one policy table as CSV to <folder><nodeName><extension>.
//folder and extension are concatenated verbatim, so folder needs its own
//trailing path separator.
//Layout:
//  SearchString,Action,
//  <search string>,<action1>-<action2>-...-,
//Every action label is followed by '-', and every row ends with ','.
//Rows are emitted in hash_map iteration order, which is unspecified.
void PolicyFileWriter::WriteToFile(string folder, string nodeName, string extension, hash_map<string, list<string>> policy)
{
	const string path = folder + nodeName + extension;
	ofstream out(path.c_str());

	out << "SearchString,Action," << endl;

	typedef hash_map<string, list<string>>::const_iterator RowIter;
	for(RowIter row = policy.begin(); row != policy.end(); ++row)
	{
		string line = row->first + ",";

		const list<string>& actions = row->second;
		for(list<string>::const_iterator action = actions.begin(); action != actions.end(); ++action)
		{
			line += *action;
			line += "-";
		}

		line += ",";
		out << line << endl;
	}

	out.close();
}


//Returns the best decision for node, given the evidence currently
//propagated in its domain.
//All states whose expected utility equals the maximum (exact
//floating-point comparison) are candidates. One of them is picked with
//rand() and returned as a one-element list.
//Returns an empty list if node is null or has no states.
list<string> PolicyFileWriter::GetBestDecision(DiscreteDecisionNode* node)
{
	list<string> decisions;
	if(node == NULL)
	{
		return decisions;
	}

	vector<string> bestStates;
	double bestUtil = 0.0;
	const size_t numStates = node->getNumberOfStates();

	for(size_t i = 0; i < numStates; i++)
	{
		//Query each state's expected utility only once
		const double util = node->getExpectedUtility(i);

		if(i == 0 || util > bestUtil)
		{
			//First state seen, or a strictly better one: restart the candidates
			bestStates.clear();
			bestUtil = util;
			bestStates.push_back(node->getStateLabel(i));
		}
		else if(util == bestUtil)
		{
			//Tie with the current best: keep it as an extra candidate
			bestStates.push_back(node->getStateLabel(i));
		}
	}

	if(bestStates.empty())
	{
		//No states: rand() % 0 would be undefined behaviour
		return decisions;
	}

	//Shrink to a single decision by choosing one of the tied states at random
	decisions.push_back(bestStates[rand() % bestStates.size()]);
	return decisions;
}